#! /usr/bin/python3
"""Builds listener and forwarder."""
import datetime
import json
import os
import re
import typing

from cki_lib import certs
from cki_lib import logger
from cki_lib import messagequeue
from cki_lib import misc
import prometheus_client
import proton
from proton import handlers
from proton import reactor

LOGGER = logger.get_logger(__name__)

# Exchange that forwarded messages are published to; None if the
# RABBITMQ_PUBLISH_EXCHANGE environment variable is unset.
RABBITMQ_PUBLISH_EXCHANGE = os.environ.get('RABBITMQ_PUBLISH_EXCHANGE')
# Keepalive for the outgoing RabbitMQ connection (default 60; presumably
# seconds, per the _S suffix — TODO confirm against cki_lib.messagequeue).
RABBITMQ_KEEPALIVE_S = misc.get_env_int('RABBITMQ_KEEPALIVE_S', 60)

# Single outgoing connection object shared by message_send(); it is
# constructed when this module is imported.
RABBITMQ_CONNECTION = messagequeue.MessageQueue(keepalive_s=RABBITMQ_KEEPALIVE_S)

# Number of messages handed to RABBITMQ_CONNECTION, labelled by routing key.
# message_send() only increments it in production/staging, where messages
# are actually published.
METRIC_MESSAGE_SENT = prometheus_client.Counter(
    'cki_message_forwarded',
    'Number of messages forwarded',
    ['routing_key']
)


def convert_unix_timestamp(unix_ts: float) -> str:
    """Convert a unix timestamp float to an isoformat string."""
    try:
        return datetime.datetime.fromtimestamp(unix_ts, tz=datetime.UTC).isoformat()
    except ValueError:
        return ''


def filtered_headers(headers: dict[str, typing.Any]) -> dict[str, typing.Any]:
    """Remove headers used by CKI and RabbitMQ."""
    return {
        k: v for k, v in headers.items()
        if not any(k.startswith(p) for p in ('x-', 'message-'))
    }


def message_send(message, routing_key, headers):
    """Publish a message to the configured RabbitMQ exchange.

    Every call is logged at debug level. In production and staging the
    message is published to RABBITMQ_PUBLISH_EXCHANGE with the given
    routing key and headers, and the per-routing-key forwarded counter is
    incremented. In any other environment nothing is published.
    """
    LOGGER.debug('Sending message. exchange=%s routing_key=%s headers=%s message=%s',
                 RABBITMQ_PUBLISH_EXCHANGE, routing_key, headers, message)
    if misc.is_production_or_staging():
        RABBITMQ_CONNECTION.send_message(
            message,
            routing_key,
            exchange=RABBITMQ_PUBLISH_EXCHANGE,
            headers=headers,
        )
        METRIC_MESSAGE_SENT.labels(routing_key).inc()
    else:
        LOGGER.info('Devel environment, not forwarding messages.')
    METRIC_MESSAGE_SENT.labels(routing_key).inc()


class AMQP10Receiver(handlers.MessagingHandler):
    """Receive messages from AMQP 1.0 and forward them to RabbitMQ.

    Deliveries are settled manually (``auto_accept=False``): a delivery is
    accepted only after the message was forwarded, and rejected if
    forwarding raised.
    """

    def __init__(self, config):
        """Initialize all the values we need.

        Args:
            config: Receiver configuration; the keys read here are
                ``name``, ``cert_path``, ``receiver_urls`` and
                ``message_topics``.
        """
        super().__init__(auto_accept=False)

        self.receiver_name = config['name']
        self.cert_path = config['cert_path']
        self.urls = config['receiver_urls']
        self.topics = self._parse_queues(config['message_topics'])

        certs.update_certificate_metrics(self.cert_path)

    @staticmethod
    def _parse_queues(queues):
        """
        Parse a list of queues and turn them into topics if necessary.

        On devel deployments, in order not to consume production messages,
        queue names need to be turned into volatile topics, e.g.
        ``queue://Consumer.a.b.VirtualTopic.x`` -> ``topic://VirtualTopic.x``.
        Entries that do not match that form are returned unchanged.
        """
        if misc.is_production_or_staging():
            return queues

        # Dots are escaped so they only match the literal separators of the
        # Consumer.<client>.<subscription>.VirtualTopic... naming scheme.
        queue_pattern = r'^queue://Consumer\.[^.]+\.[^.]+\.(VirtualTopic.*)'
        topic_replacement = r'topic://\1'

        return [
            re.sub(queue_pattern, topic_replacement, queue)
            for queue in queues
        ]

    def on_start(self, event):
        """Open one TLS connection and attach a receiver per configured topic.

        The file at ``cert_path`` is passed as both the client certificate
        and the private key; the CA bundle path is taken from the
        REQUESTS_CA_BUNDLE environment variable.
        """
        ssl = proton.SSLDomain(proton.SSLDomain.MODE_CLIENT)
        ssl.set_credentials(self.cert_path, self.cert_path, None)
        # NOTE(review): os.getenv() returns None when REQUESTS_CA_BUNDLE is
        # unset; confirm how proton handles that.
        ssl.set_trusted_ca_db(os.getenv('REQUESTS_CA_BUNDLE'))
        ssl.set_peer_authentication(proton.SSLDomain.VERIFY_PEER)
        conn = event.container.connect(urls=self.urls, ssl_domain=ssl)

        for topic in self.topics:
            event.container.create_receiver(conn, source=topic)

    def on_message(self, event):
        """Handle a single message.

        The delivery is accepted if forwarding succeeds; otherwise it is
        rejected and the exception is re-raised.
        """
        try:
            messagequeue.MessageQueue.measured_callback(
                self.callback, event
            )
        except Exception:  # pylint: disable=broad-except
            self.reject(event.delivery)
            raise
        self.accept(event.delivery)

    def callback(self, event):
        """Convert one AMQP 1.0 message and forward it to RabbitMQ.

        Dict bodies are forwarded as-is; other bodies are parsed as JSON.
        Messages whose body cannot be parsed are logged and dropped (and
        still accepted by on_message).

        Raises:
            Exception: if the message address cannot be turned into a
                routing key.
        """
        headers = {
            # proton reports a message without application properties as
            # properties=None, which must not abort forwarding.
            **filtered_headers(event.message.properties or {}),
            'message-type': 'amqp-bridge',
            'message-amqp-bridge-name': self.receiver_name,
            'message-amqp-bridge-protocol': 'amqp10',
            'message-date': convert_unix_timestamp(event.message.creation_time),
        }
        if isinstance(event.message.body, dict):
            LOGGER.debug('Not parsing complex message body')
            message = event.message.body
        else:
            LOGGER.debug('Parsing message body as JSON')
            try:
                body = event.message.body
                message = json.loads(body.tobytes() if isinstance(body, memoryview) else body)
            except (json.JSONDecodeError, TypeError):
                LOGGER.exception('Ignoring message with invalid body')
                return

        # Remove the prefix to make it independent of the routing key protocol.
        address = event.message.address
        topic = re.sub(r'^(topic://VirtualTopic|queue://Consumer\.[^.]+\.[^.]+\.VirtualTopic)',
                       'VirtualTopic',
                       address or '')
        # An unchanged (or missing) address means no known prefix was found.
        if not address or topic == address:
            raise Exception(f'Unable to convert address {address} to routing key')
        routing_key = f'{self.receiver_name}.{topic}'

        message_send(message, routing_key, headers)

    def on_link_opened(self, event):
        # pylint: disable=no-self-use
        """Log successful link info."""
        LOGGER.info('Link opened to %s at address %s',
                    event.connection.hostname,
                    event.link.source.address)

    def on_link_error(self, event):
        """Log the link error, close the connection and raise."""
        LOGGER.error('Link error: %s: %s',
                     event.link.remote_condition.name,
                     event.link.remote_condition.description)
        LOGGER.info('Closing connection to %s', event.connection.hostname)
        event.connection.close()
        raise Exception('Link error occured!')

    def on_connection_error(self, event):
        """Log the connection error, close the connection and raise."""
        handlers.EndpointStateHandler.print_error(event.connection,
                                                  'connection')
        event.connection.close()
        raise Exception('Connection error occured!')


class AMQP091Receiver:
    """Forward messages consumed from an AMQP 0.9.1 broker to RabbitMQ.

    The broker connection object is built in __init__, but no messages are
    consumed until receive_messages() is called.
    """

    def __init__(self, config):
        """Build the broker connection object and store the subscription.

        Args:
            config: Receiver configuration. Keys read: host, port, cafile,
                certfile, virtual_host, name, exchange, routing_keys and
                queue_name.
        """
        connection_options = {
            'host': config['host'],
            'port': config['port'],
            'cafile': config['cafile'],
            'certfile': config['certfile'],
            'virtual_host': config['virtual_host'],
            'dlx_retry': False,
        }
        self.connection = messagequeue.MessageQueue(**connection_options)
        self.receiver_name = config['name']
        self.exchange = config['exchange']
        self.routing_keys = config['routing_keys']
        self.queue_name = config['queue_name']

        certs.update_certificate_metrics(connection_options['certfile'])

    def callback(self, body=None, headers=None, routing_key=None, **_):
        """Republish one consumed message under a receiver-prefixed key.

        Incoming headers are filtered, and bridge metadata is added,
        including the original AMQP 0.9.1 routing key. The message is then
        forwarded with the routing key
        ``<receiver name>.<original routing key>``.
        """
        forwarded_headers = filtered_headers(headers or {})
        forwarded_headers.update({
            'message-type': 'amqp-bridge',
            'message-amqp-bridge-name': self.receiver_name,
            'message-amqp-bridge-protocol': 'amqp091',
            'message-amqp091-topic': routing_key,
        })
        message_send(body, f'{self.receiver_name}.{routing_key}', forwarded_headers)

    def receive_messages(self):
        """Consume from the configured exchange endlessly, forwarding each message."""
        self.connection.consume_messages(
            self.exchange,
            self.routing_keys,
            self.callback,
            queue_name=self.queue_name,
        )


def process_amqp091(config):
    """Receive AMQP 0.9.1 messages described by *config* and forward them."""
    AMQP091Receiver(config).receive_messages()


def process_amqp10(config):
    """Receive and forward AMQP 1.0 messages by running a proton Container."""
    reactor.Container(AMQP10Receiver(config)).run()
